{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e85xylySwifV"
      },
      "source": [
        "### Download model\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "JHxRAEhQwqNG",
        "outputId": "b3c1f7d8-9c08-4985-e2b4-4a24845265c3"
      },
      "outputs": [],
      "source": [
        "!wget https://storage.googleapis.com/dm-tapnet/tapnext/bootstapnext_ckpt.npz\n",
        "!wget https://storage.googleapis.com/dm-tapnet/tapnext/tapnext_ckpt.npz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rIx6RqUHQERV"
      },
      "source": [
        "### Download dataset "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mz_zXfrhQERW",
        "outputId": "5af24fbd-fb4c-4549-96a3-e483bb7bc738"
      },
      "outputs": [],
      "source": [
        "!wget --no-check-certificate https://storage.googleapis.com/dm-tapnet/tapvid_davis.zip\n",
        "!unzip tapvid_davis.zip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RHt3rGLLxfWs"
      },
      "outputs": [],
      "source": [
        "import torch\n",
        "import torchvision"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-gBuXTWqxuMV",
        "outputId": "42ca4293-2496-4a03-ced5-a02d9253c721"
      },
      "outputs": [],
      "source": [
        "torch.__version__, torchvision.__version__"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gjdq8wLcQWd7",
        "outputId": "8c5ed594-021c-4c67-e0df-949151d1247f"
      },
      "outputs": [],
      "source": [
        "!pip install -q git+https://github.com/google-deepmind/tapnet.git"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "mrQ_uQDeQee-",
        "outputId": "ce0d3868-7b82-4e8e-88ae-fd5ae960f95b"
      },
      "outputs": [],
      "source": [
        "!pip install -q git+https://github.com/google-deepmind/recurrentgemma.git@main"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NIZQxfcgyFM9",
        "outputId": "8cd10428-69a7-4c39-f3f9-eec9c07ed108"
      },
      "outputs": [],
      "source": [
        "!pip install \"numpy\u003c2.1.0\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gju7QkZH2XLL"
      },
      "outputs": [],
      "source": [
        "import tqdm"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "hICS3HPqcxU_",
        "outputId": "464232db-ef35-461d-bec1-0cac726bf7f4"
      },
      "outputs": [],
      "source": [
        "from tapnet import evaluation_datasets\n",
        "\n",
        "davis_dataset = evaluation_datasets.create_davis_dataset(\n",
        "    davis_points_path='tapvid_davis/tapvid_davis.pkl',\n",
        "    query_mode='first',\n",
        "    full_resolution=False,\n",
        "    resolution=(256, 256),\n",
        ")\n",
        "\n",
        "cached_dataset = []\n",
        "for j, batch in enumerate(davis_dataset):\n",
        "  cached_dataset.append(batch)\n",
        "  print(\n",
        "      'video id',\n",
        "      j,\n",
        "  )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uQ6Ako7IQERX"
      },
      "source": [
        "### TAPNext"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O3CTAohgPk_q"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "from tapnet.tapnext.tapnext_torch import TAPNext\n",
        "from tapnet.tapnext.tapnext_torch_utils import restore_model_from_jax_checkpoint, tracker_certainty\n",
        "import torch.nn.functional as F"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hTQpviwqQERc"
      },
      "outputs": [],
      "source": [
        "def run_eval_per_frame(\n",
        "    model,\n",
        "    batch,\n",
        "    get_trackwise_metrics=True,\n",
        "    radius=8,\n",
        "    threshold=0.5,\n",
        "    use_certainty=False,\n",
        "):\n",
        "  with torch.no_grad():\n",
        "    pred_tracks, track_logits, visible_logits, tracking_state = model(\n",
        "        video=batch['video'][:, :1], query_points=batch['query_points']\n",
        "    )\n",
        "    pred_visible = visible_logits \u003e 0\n",
        "    pred_tracks, pred_visible = [pred_tracks.cpu()], [pred_visible.cpu()]\n",
        "    pred_track_logits, pred_visible_logits = [track_logits.cpu()], [\n",
        "        visible_logits.cpu()\n",
        "    ]\n",
        "    for frame in tqdm.tqdm(range(1, batch['video'].shape[1])):\n",
        "      # ***************************************************\n",
        "      # HERE WE RUN POINT TRACKING IN PURELY ONLINE FASHION\n",
        "      # ***************************************************\n",
        "      (\n",
        "          curr_tracks,\n",
        "          curr_track_logits,\n",
        "          curr_visible_logits,\n",
        "          tracking_state,\n",
        "      ) = model(\n",
        "          video=batch['video'][:, frame : frame + 1],\n",
        "          state=tracking_state,\n",
        "      )\n",
        "      curr_visible = curr_visible_logits \u003e 0\n",
        "      # ***************************************************\n",
        "      pred_tracks.append(curr_tracks.cpu())\n",
        "      pred_visible.append(curr_visible.cpu())\n",
        "      pred_track_logits.append(curr_track_logits.cpu())\n",
        "      pred_visible_logits.append(curr_visible_logits.cpu())\n",
        "    tracks = torch.cat(pred_tracks, dim=1).transpose(1, 2)\n",
        "    pred_visible = torch.cat(pred_visible, dim=1).transpose(1, 2)\n",
        "    track_logits = torch.cat(pred_track_logits, dim=1).transpose(1, 2)\n",
        "    visible_logits = torch.cat(pred_visible_logits, dim=1).transpose(1, 2)\n",
        "\n",
        "    pred_certainty = tracker_certainty(tracks, track_logits, radius)\n",
        "    pred_visible_and_certain = (\n",
        "        F.sigmoid(visible_logits) * pred_certainty\n",
        "    ) \u003e threshold\n",
        "\n",
        "    if use_certainty:\n",
        "      occluded = ~(pred_visible_and_certain.squeeze(-1))\n",
        "    else:\n",
        "      occluded = ~(pred_visible.squeeze(-1))\n",
        "\n",
        "  scalars = evaluation_datasets.compute_tapvid_metrics(\n",
        "      batch['query_points'].cpu().numpy(),\n",
        "      batch['occluded'].cpu().numpy(),\n",
        "      batch['target_points'].cpu().numpy(),\n",
        "      occluded.numpy() + 0.0,\n",
        "      tracks.numpy()[..., ::-1],\n",
        "      query_mode='first',\n",
        "      get_trackwise_metrics=get_trackwise_metrics,\n",
        "  )\n",
        "  return (\n",
        "      tracks.numpy()[..., ::-1],\n",
        "      occluded,\n",
        "      {k: v.sum(0) for k, v in scalars.items()},\n",
        "  )\n",
        "\n",
        "\n",
        "# @title Function for raw data to the input format {form-width: \"25%\"}\n",
        "def deterministic_eval(cached_dataset, strided=False):\n",
        "  if not strided:\n",
        "    for sample in tqdm.tqdm(cached_dataset, disable=True):\n",
        "      batch = sample['davis'].copy()\n",
        "      # batch['video'] = (batch['video'] + 1) / 2\n",
        "      batch['visible'] = np.logical_not(batch['occluded'])[..., None]\n",
        "      batch['padding'] = np.ones(\n",
        "          batch['query_points'].shape[:2], dtype=np.bool_\n",
        "      )\n",
        "      batch['loss_mask'] = np.ones(\n",
        "          batch['target_points'].shape[:3] + (1,), dtype=np.float32\n",
        "      )\n",
        "      batch['appearance'] = np.ones(\n",
        "          batch['target_points'].shape[:3] + (1,), dtype=np.float32\n",
        "      )\n",
        "\n",
        "      yield batch\n",
        "  else:\n",
        "    for sample in tqdm.tqdm(cached_dataset):\n",
        "      batch = sample['davis'].copy()\n",
        "      # batch['video'] = (batch['video'] + 1) / 2\n",
        "      batch['visible'] = np.logical_not(batch['occluded'])[..., None]\n",
        "      batch['padding'] = np.ones(\n",
        "          batch['query_points'].shape[:2], dtype=np.bool_\n",
        "      )\n",
        "      batch['loss_mask'] = np.ones(\n",
        "          batch['target_points'].shape[:3] + (1,), dtype=np.float32\n",
        "      )\n",
        "      batch['appearance'] = np.ones(\n",
        "          batch['target_points'].shape[:3] + (1,), dtype=np.float32\n",
        "      )\n",
        "      backward_batch = {k: v.copy() for k, v in batch.items()}\n",
        "      for key in ['visible', 'appearance', 'loss_mask', 'target_points']:\n",
        "        backward_batch[key] = np.flip(backward_batch[key], axis=2)\n",
        "      backward_batch['video'] = np.flip(backward_batch['video'], axis=1)\n",
        "      backward_queries = (\n",
        "          backward_batch['video'].shape[1]\n",
        "          - backward_batch['query_points'][..., 0]\n",
        "          - 1\n",
        "      )\n",
        "      backward_batch['query_points'][..., 0] = backward_queries\n",
        "      yield batch, backward_batch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gJpHisHSQERc"
      },
      "source": [
        "### Create the model and load checkpoint"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pnTma4NKQERc"
      },
      "outputs": [],
      "source": [
        "model = TAPNext(image_size=(256, 256))\n",
        "ckpt_path = 'bootstapnext_ckpt.npz'\n",
        "model = restore_model_from_jax_checkpoint(model, ckpt_path)\n",
        "model.cuda()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CukzSyYSQERd"
      },
      "source": [
        "### Run inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 477
        },
        "id": "YNf98GICtJMa",
        "outputId": "05da0eef-b914-4d1b-f740-5152f47ddb4c"
      },
      "outputs": [],
      "source": [
        "standard_eval_scalars_list = []\n",
        "preds = []\n",
        "for batch in deterministic_eval(cached_dataset):\n",
        "  batch = {k: torch.from_numpy(v).cuda().float() for k, v in batch.items()}\n",
        "  with torch.amp.autocast('cuda', dtype=torch.float16, enabled=True):\n",
        "    tracks, occluded, scores = run_eval_per_frame(\n",
        "        model, batch, get_trackwise_metrics=False, use_certainty=False\n",
        "    )\n",
        "  standard_eval_scalars_list.append(scores)\n",
        "  preds.append((tracks, occluded))\n",
        "\n",
        "\n",
        "print('')\n",
        "print(\n",
        "    'AJ',\n",
        "    np.mean([\n",
        "        standard_eval_scalars_list[k]['average_jaccard']\n",
        "        for k in range(len(standard_eval_scalars_list))\n",
        "    ]),\n",
        ")\n",
        "print(\n",
        "    'OA',\n",
        "    np.mean([\n",
        "        standard_eval_scalars_list[k]['occlusion_accuracy']\n",
        "        for k in range(len(standard_eval_scalars_list))\n",
        "    ]),\n",
        ")\n",
        "print(\n",
        "    'PTS',\n",
        "    np.mean([\n",
        "        standard_eval_scalars_list[k]['average_pts_within_thresh']\n",
        "        for k in range(len(standard_eval_scalars_list))\n",
        "    ]),\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}
